import itertools
from collections import defaultdict
from pathlib import Path
import argparse
import os
import json
import torch
import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from ase.neighborlist import natural_cutoffs, NeighborList
from ase.io import read, Trajectory

# optional. nglview for visualization
# import nglview as nv

"""
functions for loading simulated trajectories and computing observables.
"""

def get_thermo(filename):
    """
    read thermo logs.
    """
    with open(filename, 'r') as f:
        thermo = f.read().splitlines()
        sim_time, Et, Ep, Ek, T = [], [], [], [], []
        for i in range(1, len(thermo)):
            try:
                t, Etot, Epot, Ekin, Temp = [float(x) for x in thermo[i].split(' ') if x]
                sim_time.append(t)
                Et.append(Etot)
                Ep.append(Epot)
                Ek.append(Ekin)
                T.append(Temp)
            except:
                sim_time, Et, Ep, Ek, T = [], [], [], [], []
    thermo = {
        'time': sim_time,
        'Et': Et,
        'Ep': Ep,
        'Ek': Ek,
        'T': T
    }
    return thermo

def get_test_metrics(md_dir):
    """
    read test metrics such as force error.
    """
    run_metrics = {}
    with open(md_dir / 'test_metric.json', 'r') as f:
        test_metric = json.load(f)
        
        if 'mae_f' in test_metric:
            fmae = test_metric['mae_f']
            run_metrics['fmae'] = fmae
        elif 'f_mae' in test_metric:
            fmae = test_metric['f_mae']
            run_metrics['fmae'] = fmae
        elif 'forces_mae' in test_metric:
            fmae = test_metric['forces_mae']['metric']
            run_metrics['fmae'] = fmae
            
        if 'mae_e' in test_metric:
            emae = test_metric['mae_e']
            run_metrics['emae'] = emae
        elif 'e_mae' in test_metric:
            emae = test_metric['e_mae']
            run_metrics['emae'] = emae
        elif 'energy_mae' in test_metric:
            emae = test_metric['energy_mae']['metric']
            run_metrics['emae'] = emae
            
        if 'num_params' in test_metric:
            run_metrics['n_params'] = test_metric['num_params']
        if 'running_time' in test_metric:
            run_metrics['running_time'] = test_metric['running_time']
    return run_metrics

def mae(x, y, factor):
    return np.abs(x-y).mean() * factor

def distance_pbc(x0, x1, lattices):
    delta = torch.abs(x0 - x1)
    lattices = lattices.view(-1,1,3)
    delta = torch.where(delta > 0.5 * lattices, delta - lattices, delta)
    return torch.sqrt((delta ** 2).sum(dim=-1))

def get_diffusivity_traj(pos_seq, dilation=1):
    """
    Input: B x N x T x 3
    Output: B x T
    """
    # substract CoM
    bsize, time_steps = pos_seq.shape[0], pos_seq.shape[2]
    pos_seq = pos_seq - pos_seq.mean(1, keepdims=True)
    msd = (pos_seq[:, :, 1:] - pos_seq[:, :, 0].unsqueeze(2)).pow(2).sum(dim=-1).mean(dim=1)
    diff = msd / (torch.arange(1, time_steps)*dilation) / 6
    return diff.view(bsize, time_steps-1)

def get_smoothed_diff(xyz):
    seq_len = xyz.shape[0] - 1
    diff = torch.zeros(seq_len)
    for i in range(seq_len):
        diff[:seq_len-i] += get_diffusivity_traj(xyz[i:].transpose(0, 1).unsqueeze(0)).flatten()
    diff = diff / torch.flip(torch.arange(seq_len),dims=[0])
    return diff

def get_hr(traj, bins):
    """
    compute h(r) for MD17 simulations.
    traj: T x N_atoms x 3
    """
    pdist = torch.cdist(traj, traj).flatten()
    hist, _ = np.histogram(pdist[:].flatten().numpy(), bins, density=True)
    return hist

def load_run(md_dir, xlim, bins, stability_threshold, gt_traj, hist_gt):
    """
    md_dir: directory to the finished MD simulation.
    """
    if not isinstance(md_dir, Path):
        md_dir = Path(md_dir)
    
    model_name = md_dir.parts[-2]
    seed = md_dir.parts[-1][-1]
    run = {'name': (model_name + f'_seed_{seed}'),}
    
    # get bonds
    traj = Trajectory(md_dir / 'atoms.traj')
    atoms = traj[0]
    # print("atoms: ",atoms)
    NL = NeighborList(natural_cutoffs(atoms), self_interaction=False)
    NL.update(atoms)
    bonds = NL.get_connectivity_matrix().todense().nonzero()
    bonds = torch.tensor(bonds)
    # print("bonds: ",bonds.shape,bonds)

    # process trajectory
    traj = [x.positions for x in traj]
    # print("traj: ",len(traj),traj[0:10])
    run['traj'] = torch.from_numpy(np.stack(traj))
    # print("run['traj']= ",run['traj'].shape)
    run['traj'] = torch.unique(run['traj'], dim=0) # remove repeated frames from restarting.

    # load thermo log
    run['thermo'] = get_thermo(md_dir / 'thermo.log')
    T = np.array(run['thermo']['T'])
    collapse_pt = len(T)
    md_time = np.array(run['thermo']['time'])

    # track stability
    bond_lens = distance_pbc(
        gt_traj[:, bonds[0]], gt_traj[:, bonds[1]], torch.FloatTensor([30., 30., 30.]))
    mean_bond_lens = bond_lens.mean(dim=0)
    
    for i in range(1, len(T)):
        bond_lens = distance_pbc(
            run['traj'][(i-1):i, bonds[0]], run['traj'][(i-1):i, bonds[1]], torch.FloatTensor([30., 30., 30.]))
        max_dev = (bond_lens[0] - mean_bond_lens).abs().max()
        run['max_dev'] = max_dev
        if  max_dev > stability_threshold:
            collapse_pt = i
            break
    run['collapse_pt'] = collapse_pt
    
    # compute h(r)
    hist_pred = get_hr(run['traj'][0:collapse_pt], bins)
    hr_mae = mae(hist_pred, hist_gt, xlim)
    run['hr'] = hist_pred
    run['hr_error'] = hr_mae
    
    # load test metrics
    if (md_dir / 'test_metric.json').exists():
        test_metrics = get_test_metrics(md_dir)
        run.update(test_metrics)
    
    return run

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--molecule", required=True, type=str)
    parser.add_argument("--model_dir", type=str)
    parser.add_argument("--identifier", type=str)
    parser.add_argument("--save_freq", type=int)
    parser.add_argument("--steps", type=int)

    args, override_args = parser.parse_known_args()

    # eval parameters
    stability_threshold = 0.5
    xlim = 10
    n_bins = 500
    bins = np.linspace(1e-6, xlim, n_bins + 1) # for computing h(r)

    # select molecule and get ground truth data
    molecule = args.molecule
    DATAPATH = f'dataset/md17_{molecule}.npz'
    gt_data = np.load(DATAPATH)
    gt_traj = torch.FloatTensor(gt_data.f.R) # (211762, 21, 3)     
    print(gt_traj.shape) 
    hist_gt = get_hr(gt_traj, bins)

    # load run and plot h(r)
    stability_threshold = 0.5
    md_dir = Path(args.model_dir)
    run = load_run(md_dir, xlim, bins, stability_threshold, gt_traj, hist_gt)
    plt.plot(bins[1:], hist_gt, label='Reference', linewidth=2, linestyle='--')
    plt.plot(bins[1:], run['hr'], label='Prediction', linewidth=2, linestyle='--')
    plt.xlabel('r')
    plt.ylabel('h(r)')
    plt.legend() 

    # metrics
    force_mae = run['fmae'] * 1000
    collapse_ps = (run['collapse_pt']-1) / 20
    hr_mae = run['hr_error'] 
    print(f'force mae: {force_mae:.1f} meV/A \nstability: {collapse_ps:.1f} ps \nh(r) mae: {hr_mae:.2f}')
    print(run['max_dev'])